import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from idspn import EnergyObjective, iDSPN, MSEObjective, clevr_project, ProjectSimplex
from dspn import DSPN
from slot import SlotAttention


def create_pairs(a, b=None):
    """Pair every element of ``a`` with every element of ``b``.

    Both inputs are indexed as ``(batch, elements, features)``. When ``b`` is
    omitted, ``a`` is paired with itself.

    Returns:
        A tuple ``(a_pairs, b_pairs)``, each of shape
        ``(batch, LA, LB, features)``, such that ``a_pairs[n, i, j]`` equals
        ``a[n, i]`` and ``b_pairs[n, i, j]`` equals ``b[n, j]``.
    """
    other = a if b is None else b
    num_a, num_b = a.size(1), other.size(1)
    # Insert a broadcast axis on each side, then expand it to the size of the
    # opposite set so that both results share the same 4-D shape.
    left = a.unsqueeze(2).expand(-1, -1, num_b, -1)
    right = other.unsqueeze(1).expand(-1, num_a, -1, -1)
    return left, right


class FSPool(nn.Module):
    """
        Featurewise sort pooling in its simplified form: the set size is fixed and there is no masking for
        variable-size sets. From:
        FSPool: Learning Set Representations with Featurewise Sort Pooling.
        Yan Zhang, Jonathon Hare, Adam Prügel-Bennett
        https://arxiv.org/abs/1906.02795
        https://github.com/Cyanogenoid/fspool
    """

    def __init__(self, in_channels, set_size):
        super().__init__()
        # One learnable coefficient per (sorted position, channel) pair.
        initial = torch.zeros(set_size, in_channels)
        nn.init.kaiming_normal_(initial, mode='fan_out', nonlinearity='linear')
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        # Sort each channel independently along the set dimension (dim 1),
        # then take the weighted sum over sorted positions.
        ranked = torch.sort(x, dim=1).values
        return torch.einsum('nlc, lc -> nc', ranked, self.weight)


class FSEncoder(nn.Module):
    """Set encoder: a two-layer MLP applied to the last dimension, followed by FSPool.

    Args:
        input_channels: expected size of the last dimension of the input.
        dim: hidden width of the MLP.
        output_channels: width of the MLP output and of the pooled vector.
        set_size: set size passed on to FSPool.
    """

    def __init__(self, input_channels, dim, output_channels, set_size):
        super().__init__()
        self.d_in = input_channels
        layers = [
            nn.Linear(input_channels, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, output_channels),
        ]
        self.mlp = nn.Sequential(*layers)
        self.pool = FSPool(output_channels, set_size)

    def forward(self, x):
        # Guard against feeding features of the wrong width.
        assert x.size(-1) == self.d_in
        return self.pool(self.mlp(x))



class SetToEnergy(nn.Module):
    """Map a set to a single scalar per batch entry.

    The set is pooled by an FSEncoder of width ``d_hid`` and the pooled vector
    is passed through a small MLP whose final layer has one output and no bias.
    """

    def __init__(self, d_in, d_hid, set_size):
        super().__init__()
        self.enc = FSEncoder(d_in, d_hid, d_hid, set_size)
        head = [
            nn.Linear(d_hid, d_hid),
            nn.ReLU(),
            nn.Linear(d_hid, 1, bias=False),
        ]
        self.mlp = nn.Sequential(*head)

    def forward(self, x):
        pooled = self.enc(x)
        energy = self.mlp(pooled)
        # Remove the singleton dimension produced by the final Linear layer.
        return energy.squeeze(1)


class RNFSEncoder(FSEncoder):
    """FSEncoder applied to all ordered pairs of set elements.

    Each pair is the concatenation of two elements, so the underlying encoder
    sees ``2 * input_channels`` features and a set of ``set_size ** 2`` pairs.
    """

    def __init__(self, input_channels, dim, output_channels, set_size):
        pair_channels = 2 * input_channels
        num_pairs = set_size ** 2
        super().__init__(pair_channels, dim, output_channels, num_pairs)

    def forward(self, x):
        first, second = create_pairs(x)
        # Concatenate the two members of each pair along the feature axis and
        # merge the two element axes into one "set of pairs" axis.
        pairs = torch.cat((first, second), dim=-1)
        pairs = pairs.flatten(1, 2)
        return super().forward(pairs)


class ImageModel(nn.Module):
    """ ResNet18-based image encoder to turn an image into a feature vector

    The ResNet trunk is followed by batch norm and a 2x2, stride-2 convolution,
    and the result is flattened per batch entry. ``image_size`` must be at
    least 64: smaller values make ``image_size // 32 // 2`` zero, which causes
    a division by zero when the number of output channels is computed.
    """

    def __init__(self, latent, image_size):
        super().__init__()
        backbone = torchvision.models.resnet18()
        # Keep every child module of the ResNet except the last two.
        trunk = list(backbone.children())[:-2]
        self.layers = nn.Sequential(*trunk)
        resnet_output_dim = 512
        # Side length after the resnet (// 32) and after the strided conv (// 2).
        spatial_size = (image_size // 32) // 2
        # Chosen so that channels * spatial_size**2 is `latent` when it divides evenly.
        out_channels = latent // spatial_size**2
        self.end = nn.Sequential(
            nn.BatchNorm2d(resnet_output_dim),
            nn.Conv2d(resnet_output_dim, out_channels, 2, stride=2),
        )

    def forward(self, x):
        features = self.end(self.layers(x))
        batch = features.size(0)
        # Flatten channels and spatial positions into one vector per image.
        return features.view(batch, -1)


class DSPNModel(nn.Module):
    """Set model that encodes the input with an FSEncoder and decodes it with DSPN or iDSPN.

    ``implicit`` selects iDSPN (optimised with SGD built from ``lr`` and
    ``momentum``) over DSPN. The arguments ``input_encoder``,
    ``decoder_encoder``, ``image_input`` and ``image_size`` are accepted but
    not used by this constructor.
    """

    def __init__(self, input_dim, d_in, d_hid, d_latent, set_size, lr=1, iters=20, momentum=0.9, grad_clip=None, input_encoder='rnfs', decoder_encoder='fs', use_starting_set=False, image_input=False, image_size=None, implicit=False):
        super().__init__()
        self.lr = lr
        self.iters = iters
        self.implicit = implicit

        x_dim = input_dim
        self.enc = FSEncoder(x_dim, d_hid, d_latent, set_size)

        # The decoder's inner encoder sees d_in + x_dim features per element.
        decoder_set_encoder = FSEncoder(d_in + x_dim, d_hid, d_latent, set_size)
        objective = MSEObjective(decoder_set_encoder, regularized=use_starting_set)

        if self.implicit:

            def make_optimizer(params):
                # Reads self.lr at call time, so later changes to it are honoured.
                return torch.optim.SGD(params, lr=self.lr, momentum=momentum, nesterov=momentum > 0)

            # No projection is passed to the implicit variant.
            self.dspn = iDSPN(
                objective=objective,
                optim_f=make_optimizer,
                optim_iters=self.iters,
                set_channels=d_in,
                set_size=set_size,
                grad_clip=grad_clip,
                use_starting_set=use_starting_set,
            )
        else:
            self.dspn = DSPN(
                objective=objective,
                set_channels=d_in,
                max_set_size=set_size,
                channels=d_latent,
                grad_clip=grad_clip,
                iters=self.iters,
                lr=self.lr,
                projection=ProjectSimplex.apply,
            )

    def forward(self, x):
        latent = self.enc(x)
        # The decoder receives both the latent vector and the original input.
        output, set_grad = self.dspn(latent, x)
        return output, set_grad


class SlotAttentionModel(nn.Module):
    """Linear projection, SlotAttention (3 iterations, with layer norm enabled), linear projection.

    ``forward`` returns ``(output, None)``.
    """

    def __init__(self, d_in, d_hid, d_out, set_size):
        super().__init__()
        self.preproj = nn.Linear(d_in, d_hid)
        self.slot_attention = SlotAttention(
            set_size,
            d_hid,
            hidden_dim=d_hid,
            use_ln=True,
            iters=3,
        )
        self.postproj = nn.Linear(d_hid, d_out)

    def forward(self, x):
        slots = self.slot_attention(self.preproj(x))
        return self.postproj(slots), None


class DSLayer(nn.Module):
    """Layer that sums a per-element linear map with a linear map of the mean over dim 1.

    The mean term has a singleton dim 1 and is broadcast across all elements.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.lin0 = nn.Linear(d_in, d_out)
        self.lin1 = nn.Linear(d_in, d_out)

    def forward(self, x):
        per_element = self.lin0(x)
        pooled = x.mean(1, keepdim=True)
        return per_element + self.lin1(pooled)

class DSModel(nn.Module):
    """Two DSLayers with a ReLU in between; ``forward`` returns ``(output, None)``."""

    def __init__(self, d_in, d_hid, d_out):
        super().__init__()
        stack = (
            DSLayer(d_in, d_hid),
            nn.ReLU(),
            DSLayer(d_hid, d_out),
        )
        self.ds = nn.Sequential(*stack)

    def forward(self, x):
        out = self.ds(x)
        return out, None


class LSTMModel(nn.Module):
    """Linear projection, two-layer batch-first LSTM, linear projection.

    ``forward`` returns ``(output, None)``, where ``output`` has one vector of
    size ``d_out`` per input position.
    """

    def __init__(self, d_in, d_hid, d_out, bidirectional=True):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_hid)
        lstm_config = dict(
            input_size=d_hid,
            hidden_size=d_hid,
            num_layers=2,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.rnn = nn.LSTM(**lstm_config)
        # A bidirectional LSTM concatenates both directions in its output.
        num_directions = 2 if bidirectional else 1
        self.proj_out = nn.Linear(num_directions * d_hid, d_out)

    def forward(self, x):
        hidden = self.proj_in(x)
        # Only the per-step outputs are used; the final states are discarded.
        states, _ = self.rnn(hidden)
        return self.proj_out(states), None


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    Even channels hold ``sin(position * frequency)`` and odd channels hold
    ``cos(position * frequency)``. Both even and odd ``d_model`` are supported.

    Args:
        d_model: size of the embedding dimension.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions for which encodings are precomputed.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd-indexed channel than
        # even-indexed ones, so the last frequency is dropped for the cosine.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [batch_size, seq_len, embedding_dim]

        Returns:
            Tensor of the same shape as ``x``.
        """
        # The buffer is stored as [max_len, 1, d_model], so move the sequence
        # axis to the front, add, and move it back afterwards.
        x = x.transpose(0, 1)
        x = x + self.pe[:x.size(0)]
        return self.dropout(x).transpose(0, 1)


class TransformerModel(nn.Module):
    """Encoder-decoder Transformer whose source and target are both projections of the input.

    The target projection additionally receives positional encodings. No masks
    are passed to the transformer. ``set_size`` is accepted but not used.
    ``forward`` returns ``(output, None)``.
    """

    def __init__(self, d_in, d_hid, d_out, set_size):
        super().__init__()
        transformer_config = dict(
            nhead=8,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=2*d_hid,
            dropout=0.,
            batch_first=True,
        )
        self.transformer = nn.Transformer(d_hid, **transformer_config)
        self.pe = PositionalEncoding(d_hid, dropout=0.)
        self.proj_src = nn.Linear(d_in, d_hid)
        self.proj_tgt = nn.Linear(d_in, d_hid)
        self.proj_out = nn.Linear(d_hid, d_out)

    def forward(self, x):
        source = self.proj_src(x)
        # Only the decoder input is given position information.
        target = self.pe(self.proj_tgt(x))
        decoded = self.transformer(source, target)
        return self.proj_out(decoded), None


class RandomTransformerModel(nn.Module):
    """Encoder-decoder Transformer whose decoder input is half projected input, half Gaussian noise.

    The noise is shifted by the learned ``alpha`` and scaled by
    ``exp(beta)``; both parameters start at zero. ``set_size`` is accepted but
    not used. ``forward`` returns ``(output, None)``.
    """

    def __init__(self, d_in, d_hid, d_out, set_size):
        super().__init__()
        transformer_config = dict(
            nhead=8,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=2*d_hid,
            dropout=0.,
            batch_first=True,
        )
        self.transformer = nn.Transformer(d_hid, **transformer_config)
        self.proj_src = nn.Linear(d_in, d_hid)
        # Only half of the decoder width comes from the input; noise fills the rest.
        self.proj_tgt = nn.Linear(d_in, d_hid//2)
        self.proj_out = nn.Linear(d_hid, d_out)
        self.alpha = nn.Parameter(torch.zeros(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        source = self.proj_src(x)
        content = self.proj_tgt(x)
        # Fresh standard-normal noise on every call, with learned shift and scale.
        noise = torch.randn_like(content)
        noise = self.alpha + torch.exp(self.beta) * noise
        target = torch.cat([content, noise], dim=-1)
        decoded = self.transformer(source, target)
        return self.proj_out(decoded), None